"""
Filename: main_gwaste.jl
Paper: Unemployment and the Distribution of Liquidity
Authors: Zach Bethune and Guillaume Rocheteau
Contact: bethune@rice.edu
Last Modified: 8/22/2023
Purpose: main program file for model with wasteful money creation
"""
###############################################################################
###############################################################################
# Preliminaries: update directory and load packages and other program files
###############################################################################
###############################################################################
#Set the working directory to the folder containing this file, so that the relative paths used below
#("../model_data_gwaste", "../empirical_data", "../figures") resolve on any machine.
#When code is evaluated line-by-line in a plain REPL, @__DIR__ returns the current working directory;
#in that case `cd` into the code_gwaste folder manually before running.
cd(@__DIR__)

#Add packages (if needed)
include("add_packages.jl")

#Load packages 
using Plots; pyplot()
using Roots
using Interpolations
using LinearAlgebra
using IterativeSolvers
using NLsolve
using CSV
using DataFrames
using Random
using JLD2
using FileIO
using QuantEcon
using Kronecker
using LaTeXStrings
using Optim

#Load program files (keep this order: the later files presumably rely on definitions in globals_gwaste.jl)
include("globals_gwaste.jl") #holds global structures, moments, parameters, and functional forms
include("ss_gwaste.jl"); #holds all functions to compute a steady-state equilibrium 
include("dynamics_gwaste.jl"); #holds all functions to compute dynamic equilibria
include("welfare_cost_gwaste.jl"); #holds functions used to compute the welfare cost of changing the growth rate of money
include("comp_stat_Rm_gwaste.jl"); #holds functions used to compute steady-state comp. statics wrt Rm

###############################################################################
###############################################################################
# Initialize steady state equilibrium objects
###############################################################################
###############################################################################
mm = model_outcomes();

# Wealth grid: Na points evenly spaced in sqrt(a) between agrid_min and agrid_max, then squared,
# which places more grid points near agrid_min.
mm.agrid = range(agrid_min^(1.0/2.0), stop=agrid_max^(1.0/2.0), length=Na) .^ 2.0;

# Every (Na, 2, Nz) state-space array starts as its own freshly allocated array of zeros
# (assigned in the order Wa, Wa_p, am_p, a_p, c, ystar, g).
for field in (:Wa, :Wa_p, :am_p, :a_p, :c, :ystar, :g)
    setproperty!(mm, field, zeros(Na, 2, Nz))
end

# Starting values only; each is overwritten later.
mm.Ag = 0.2;     # updated in the calibration routine
mm.Zmd = 1.0;    # updated when solving the steady state
mm.theta = 0.3;  # updated when solving the steady state

###############################################################################
###############################################################################
# Calibrate steady state
###############################################################################
###############################################################################
#Run the calibration script. NOTE(review): it presumably writes ../model_data_gwaste/base_calib.jld2,
#which the next section loads; if that file already exists, this step can likely be skipped -- confirm.
include("calib_gwaste.jl")

###############################################################################
###############################################################################
# Compute moments in calibrated ss and create figures in Section 5
###############################################################################
###############################################################################
#Load the calibrated parameters, equilibrium objects, model settings and moments in one step:
#all four objects are read (in this order) from the same calibration file.
parms_base, mm, md, mmts = let calib_file = "../model_data_gwaste/base_calib.jld2"
    (load(calib_file, "parms_base"),
     load(calib_file, "mm"),
     load(calib_file, "md"),
     load(calib_file, "mmts"))
end;

#plot distribution of liquid wealth to income (Figure 3 - left panel)
#Model distribution (mmts.Gmi against mmts.li_range) overlaid on the SCF distribution read from CSV
#(column 1 = liquid wealth to income, column 2 = percentile, as plotted). Saved to ../figures.
lid_data=CSV.read("../empirical_data/liquid_income_dist.csv",DataFrame);
plot(mmts.li_range,mmts.Gmi,
        lw=3,
        color=:darkblue,
        label="Model",
        xlabel="Liquid Wealth to Annual Income",
        ylabel="Percentile",
        title="Targeted Liquid Wealth Distribution",
        tickfont=font(12,"serif"),
        guidefont=font(14,"serif"),
        legendfont=font(12,"serif"),
        fontfamily="serif",
        legend=:bottomright,
        xlim=(0.0,1.2),
        dpi=200)
#overlay the data on the current plot
plot!(lid_data[:,1], lid_data[:,2],
    line=(3, :dash),
    label="SCF 1998-2013",
    color=:darkgreen)
savefig("../figures/liquid_wealth_dist_fit_gwaste.pdf")

#plot distribution of share of wealth held in liquid assets (Figure 3 - right panel)
#Model distribution (mmts.Gls against mmts.ls_range) vs. SCF data; shares are scaled by 100 to percent.
lsd_data=CSV.read("../empirical_data/liquid_share_dist.csv",DataFrame);
plot(mmts.ls_range*100,mmts.Gls,
        lw=3,
        color=:darkblue,
        label="Model",
        xlabel="Liquid Share (%)",
        title="Targeted Liquid Share Distribution",
        ylabel="Percentile",
        tickfont=font(12,"serif"),
        guidefont=font(14,"serif"),
        legendfont=font(12,"serif"),
        fontfamily="serif",
        legend=:bottomright,
        dpi=200)
#overlay the data on the current plot
plot!(lsd_data[:,1]*100, lsd_data[:,2],
    line=(3, :dash),
    label="SCF 1998-2013",
    color=:darkgreen)
savefig("../figures/liquid_share_dist_fit_gwaste.pdf")

#Print the Gini coefficient stored in mmts.gini, rounded to 3 digits.
#NOTE(review): this comment originally said "total wealth" while the message says "wealth to income";
#confirm which distribution mmts.gini is computed on.
print("\n\tGini coefficient on wealth to income = ", round(mmts.gini,digits=3));

#Compute MPCs using sequence-space jacobian at steady state
T=100 #length of transition
dx=1e-4; #step size of the finite-difference derivative approximation
Rmt=ones(T)*mm.Rm; #constant real return on money in ss
Rft=ones(T)*mm.Rf; #constant real return on (partially) illiquid wealth in ss
pyt=ones(T)*mm.py; #constant early consumption price in ss
taut=ones((T,2)).*mm.tau[1,1,1]; #constant transfers in ss
wagest = wage_block(pyt; parms=parms_base,mm=mm,md); #constant wages in ss
thetat=ones(T)*mm.theta; #constant tightness in ss
J_ha = ha_jacobian(Rmt,Rft,pyt,taut,wagest,thetat;dx,mm=mm,mm_new=mm,parms=parms_base); #compute ss Jacobian of household block
Et = J_ha[9]; #expectation matrix 

"""
    ha_state_range(e, z; na=Na)

Return the indices along the third axis of `Et` holding the `na` wealth-grid points for
employment state `e` (1 = employed, 2 = unemployed) and productivity state `z`.
That axis stacks the (Na, 2, Nz) household state in column-major order:
(e=1, z=1), (e=2, z=1), (e=1, z=2), (e=2, z=2), ...
"""
ha_state_range(e, z; na=Na) = (2 * (z - 1) + (e - 1)) * na .+ (1:na)

"""
    mpc_by_state(Et, e, z; py=mm.py, agrid=mm.agrid)

Return the finite-difference MPC out of liquid wealth along the wealth grid for state `(e, z)`:
the slope with respect to `agrid` of `py * Et[3, j, ⋅] + Et[5, j, ⋅]`, summed over `j = 1:3`.
The result has `length(agrid) - 1` entries (one per grid interval).
"""
function mpc_by_state(Et, e, z; py=mm.py, agrid=mm.agrid)
    idx = ha_state_range(e, z)
    # NOTE(review): rows 3 and 5 appear to be early (priced at py) and late consumption, and j = 1:3
    # the first three periods of the expectation vectors -- confirm against ha_jacobian.
    rowsum_diff(row) = diff(Et[row, 1, idx]) .+ diff(Et[row, 2, idx]) .+ diff(Et[row, 3, idx])
    return (rowsum_diff(3) .* py .+ rowsum_diff(5)) ./ diff(agrid)
end

#compute mpcs (by wealth, employment status, and labor productivity)
#suffix digit: 1 = employed, 0 = unemployed; suffix letter: l/m/h = productivity state 1/2/3
mpc1l = mpc_by_state(Et, 1, 1); mpc0l = mpc_by_state(Et, 2, 1);
mpc1m = mpc_by_state(Et, 1, 2); mpc0m = mpc_by_state(Et, 2, 2);
mpc1h = mpc_by_state(Et, 1, 3); mpc0h = mpc_by_state(Et, 2, 3);
#average over productivity with weights gz
mpc0 = mm.gz[1].*mpc0l.+mm.gz[2].*mpc0m.+mm.gz[3].*mpc0h;
mpc1 = mm.gz[1].*mpc1l.+mm.gz[2].*mpc1m.+mm.gz[3].*mpc1h;

#plot MPCs by wealth and employment status (averaging over productivity) - Figure 4
#mpc0/mpc1 come from first differences across the wealth grid, so they have one fewer point than the grid;
#they are plotted against mmts.G[1:end-1] (presumably the cumulative wealth distribution on the grid).
#NOTE(review): unlike the other figures, this one is not passed to savefig here.
plot(mmts.G[1:end-1],mpc0,
    lw=4,
    color=:darkblue,
    label="Unemployed",
    xlabel="Percentile of Total Wealth",
    title="Average MPC out of Liquid Wealth",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    legend=:topright,
    right_margin=(150,:px),
    dpi=200)
#overlay employed households on the current plot
plot!(mmts.G[1:end-1],mpc1,
    line=(3, :dash),
    label="Employed",
    color=:darkred)
savefig("../figures/MPCs_by_wealth_employment_gwaste.pdf")

#Compute and plot consumption and saving responses to job loss, by wealth percentile - Figure 4
#The third axis of Et stacks the (Na, 2, Nz) household state column-major, so for productivity state z
#the employed block of Na grid points is immediately followed by the unemployed block.
job_loss_blocks(z; na=Na) = ((2z - 2) * na .+ (1:na), (2z - 1) * na .+ (1:na))

"""
    job_loss_pct_change(x, scale; gz=mm.gz)

Return the change in `x` (a vector along Et's state axis) from being employed to being unemployed
at the same wealth grid point, `(x_unemployed / x_employed - 1) * scale`, averaged over the Nz
productivity states with weights `gz`. The result has `Na` entries.
"""
function job_loss_pct_change(x, scale; gz=mm.gz)
    return sum(((x[last(job_loss_blocks(z))] ./ x[first(job_loss_blocks(z))]) .- 1.0) .* scale .* gz[z]
               for z in 1:Nz)
end

TC = mm.py.*Et[3,1,:].+Et[5,1,:]; #total consumption: py * Et[3,1,:] + Et[5,1,:]
closs = job_loss_pct_change(TC, 100.0); #% change in consumption after job loss
closs_avg = sum(sum(mm.g,dims=(2,3))[1:end].*closs); #average over the wealth distribution mm.g
wchg = job_loss_pct_change(Et[4,1,:], 100); #% change in total savings
lwchg = job_loss_pct_change(Et[1,1,:], 100); #% change in liquid savings
plot(mmts.G,wchg,
    lw=4,
    color=:black,
    label="Total savings",
    xlabel="Percentile of Wealth",
    ylabel="Percentage Change (%)",
    title="Change in Consumption and Savings After Job Loss",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    legend=:bottomright,
    right_margin = (100,:px),
    dpi=200)
plot!(mmts.G,lwchg,
    line=(4, :dash),
    label="Liquid savings",
    color=:darkviolet)
plot!(mmts.G,closs,
    line=(4, :solid),
    label="Consumption",
    color=:darkgreen)
savefig("../figures/cons_savings_job_loss_gwaste.pdf")

#Compute consumption and savings rates in the cross-section of wealth - Figure 5
#Early-good consumption: weights (1 - ALPHAmf) and ALPHAmf on min(ystar, am_p/py) and min(ystar, a_p/py)
earlyc = (1.0-parms_base.ALPHAmf).*(min.(mm.ystar,(mm.am_p./mm.py)));
earlyc += (parms_base.ALPHAmf).*(min.(mm.ystar,(mm.a_p./mm.py)));
cons = mm.c.+parms_base.ALPHA*mm.py.*earlyc; #late consumption plus ALPHA-weighted early-good spending
efrac = (parms_base.ALPHA*mm.py.*earlyc)./cons; #early-good share of consumption
ls=mm.am_p./mm.a_p; #liquid share of next-period assets
#net saving: next-period holdings (discounted at Rf / Rm) minus current wealth agrid split by ls
netsavingi = (1.0/mm.Rf)*((mm.a_p.-mm.am_p).-mm.agrid.*(1.0.-ls));
netsavingm = (1.0/mm.Rm)*(mm.am_p.-mm.agrid.*ls);
inc = cons .+ netsavingi .+ netsavingm;
cr = cons./inc;
isr = netsavingi./inc;
lsr = netsavingm./inc;
sr = (netsavingi.+netsavingm)./inc;
ecr = (parms_base.ALPHA*mm.py.*earlyc)./inc;
lcr = mm.c./inc;

#left panel of Figure 5
#The rates above are averaged over productivity (weights gz) for employed (index 1) and unemployed
#(index 2) households, then over employment status (weights emp). They are not recomputed here.
avg_over_z(x, e; gz=mm.gz) = sum(x[:, e, z] .* gz[z] for z in 1:Nz)
sr_1 = avg_over_z(sr, 1);
sr_0 = avg_over_z(sr, 2);
sr_agg = sr_1.*mm.emp .+ sr_0.*(1.0-mm.emp);
lsr_1 = avg_over_z(lsr, 1);
lsr_0 = avg_over_z(lsr, 2);
lsr_agg = lsr_1.*mm.emp .+ lsr_0.*(1.0-mm.emp);
isr_1 = avg_over_z(isr, 1);
isr_0 = avg_over_z(isr, 2);
isr_agg = isr_1.*mm.emp .+ isr_0.*(1.0-mm.emp);
p1=plot(mmts.G,sr_agg*100,
    lw=4,
    color=:black,
    label="Net Total Savings Rate",
    xlabel="Percentile of Total Wealth",
    ylabel="Percent %",
    title="Savings Rates",
    tickfont=font(14,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(14,"serif"),
    fontfamily="serif",
    legend=:topright,
    right_margin=(150,:px),
    dpi=200)
plot!(mmts.G,lsr_agg*100,
    line=(4, :dash),
    label="Net Liquid Savings Rate",
    color=:darkred)
plot!(mmts.G,isr_agg*100,
    line=(4, :dot),
    label="Net Illiquid Savings Rate",
    color=:darkgreen)
#

#middle panel of Figure 5
ls = mm.am_p./mm.a_p;
#Average liquid share at each wealth grid point: average over all Nz productivity states with weights gz
#(gz[z] paired with ls[:,·,z]), separately for employed (index 1) and unemployed (index 2) households,
#then over employment status with weights emp -- the same weighting used for the savings rates.
ls_avg = mm.emp .* sum(mm.gz[z] .* ls[:, 1, z] for z in 1:Nz) .+
    (1.0 - mm.emp) .* sum(mm.gz[z] .* ls[:, 2, z] for z in 1:Nz);
ls_wpct_dat=CSV.read("../empirical_data/liquid_share_by_wealth_percentiles_data.csv",DataFrame);
p2=plot(mmts.G,ls_avg*100,
    lw=4,
    color=:darkblue,
    label="Model",
    title="Average Liquid Share (%)",
    xlabel="Percentile of Total Wealth",
    tickfont=font(14,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(14,"serif"),
    fontfamily="serif",
    legend=:topright,
    right_margin=(150,:px),
    dpi=200,
    xlims=(0.0,1.0))
#SCF means with error bars of 1.97 standard errors
scatter!((ls_wpct_dat.Variable/100), ls_wpct_dat.mean*100,
    label="Data (SCF)",
    yerror=ls_wpct_dat.se*100*1.97,
    lw=2,
    ms=7,
    color=:green)
#

#right panel of Figure 5
#Average liquid wealth am_p relative to annual income (12 x the monthly wage wages_bar), weighted by
#productivity (gz) and employment status (emp; index 1 = employed, 2 = unemployed).
#The accumulation happens inside a `let` block: a bare top-level `for` loop that reassigns a global
#(as `linca = linca .+ ...` did) fails with UndefVarError when this file is run non-interactively,
#because of Julia's soft-scope rule.
linca = let acc = zeros(Na)
    for k in 1:Nz
        acc = acc .+ (mm.am_p[:,1,k]./(12*mm.wages_bar[1,k])).*(mm.gz[k]*mm.emp)
        acc = acc .+ (mm.am_p[:,2,k]./(12*mm.wages_bar[2,k])).*(mm.gz[k]*(1-mm.emp))
    end
    acc
end;
li_wpct_dat=CSV.read("../empirical_data/liquid_inc_by_wealth_percentiles_data.csv",DataFrame);
p3=plot(mmts.G,linca,
    lw=4,
    color=:darkblue,
    label="Model",
    title="Average Liquid Wealth to Income",
    xlabel="Percentile of Total Wealth",
    tickfont=font(14,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(14,"serif"),
    fontfamily="serif",
    legend=:topleft,
    right_margin=(50,:px),
    dpi=200,
    xlims=(0.0,1.02),
    ylims=(0.0,0.42))
#SCF 2004 means with error bars of 1.97 standard errors
scatter!((li_wpct_dat.Variable/100), li_wpct_dat.mean_2004,
    label="Data (SCF)",
    yerror=li_wpct_dat.se_2004*1.97,
    lw=2,
    ms=7,
    color=:green)
#

#Figure 5
#NOTE(review): unlike the other figures, Figure 5 is displayed but not saved with savefig.
plot(p1,p2,p3,layout=(1,3),size=(1800,450))

###############################################################################
###############################################################################
# Steady state comparative statics wrt Rm - Section 6.1
###############################################################################
###############################################################################
#load baseline calibration and equilibrium objects
#(md is loaded too: the counterfactual Phillips curves below use md.LABOR_SHARE, so this section
#can run on its own like the later sections, which also reload md)
parms_base=load("../model_data_gwaste/base_calib.jld2","parms_base");
mm=load("../model_data_gwaste/base_calib.jld2","mm");
md=load("../model_data_gwaste/base_calib.jld2","md");
mmts=load("../model_data_gwaste/base_calib.jld2","mmts");

#compute comparative statics
N=20; #length of money growth vector
inf_grid = range(0,stop=15,length=N); #vector of annual inflation rates (%), from 0 to 15
Rm_grid = 1.0./((1.0.+(inf_grid./100)).^(1.0/12.0)); #monthly real returns on money implied by each annual inflation rate
#NOTE(review): results are read back below from ../model_data_gwaste/compstat_i.jld2, so Rm_sim presumably writes those files
Rm_sim(Rm_grid,mm=mm,parms=parms_base,mmts=mmts); 

#load results from comparative statics and compute counterfactual Phillips curves

"""
    u_counterfactual(py, Rf, zgrid; parms=parms_base, labor_share=md.LABOR_SHARE)

Return the steady-state unemployment rate implied by early-consumption price `py` and illiquid
return `Rf`, following the chain:
- early-good output `Ys = κ_prime_inv(py)`;
- revenue `zgrid .* (1 + py*Ys - κ(Ys))`;
- profits net of the labor share;
- value `Js = profits / (1 - (1 - DELTA)/Rf)`;
- tightness `q_inv(zgrid*K*Rf/Js)` at the first productivity point;
- unemployment `DELTA / (DELTA + λ(theta))`.

DELTA is presumably the job-separation rate.
"""
function u_counterfactual(py, Rf, zgrid; parms=parms_base, labor_share=md.LABOR_SHARE)
    Ys = κ_prime_inv(py; parms=parms)
    frev = zgrid .* (1.0 + py * Ys - κ(Ys; parms=parms))
    profits = (1.0 - labor_share) .* frev
    Js = profits ./ (1.0 - ((1.0 - parms.DELTA) / Rf))
    theta = q_inv(((zgrid .* parms.K * Rf) ./ Js)[1]; parms=parms)
    return parms.DELTA / (parms.DELTA + λ(theta; parms=parms))
end

Rm_grid = zeros(N); py_grid = zeros(N); Rf_grid = zeros(N);
MD_grid = zeros(N); Jd_grid = zeros(N);
U_cf = zeros((N,3)); #columns: 1 = both channels, 2 = py only, 3 = Rf only
for i=1:N
    mm_cs=load(string("../model_data_gwaste/compstat_",i,".jld2"),"mm");
    Rm_grid[i] = mm_cs.Rm
    py_grid[i] = mm_cs.py
    Rf_grid[i] = mm_cs.Rf
    MD_grid[i] = sum(mm_cs.am_p.*mm_cs.g) #aggregate money demand
    Jd_grid[i] = mm_cs.Jd

    #py only: py moves with inflation, Rf held at its pi=0 value (grid point 1)
    U_cf[i,2] = u_counterfactual(py_grid[i], Rf_grid[1], mm_cs.zgrid; parms=parms_base, labor_share=md.LABOR_SHARE)

    #Rf only: Rf moves with inflation, py held at its pi=0 value
    U_cf[i,3] = u_counterfactual(py_grid[1], Rf_grid[i], mm_cs.zgrid; parms=parms_base, labor_share=md.LABOR_SHARE)

    #Both
    U_cf[i,1] = u_counterfactual(py_grid[i], Rf_grid[i], mm_cs.zgrid; parms=parms_base, labor_share=md.LABOR_SHARE)
end

#Plot (Figure 6)
#Comparative statics across inf_grid: annualized illiquid return, and % changes of py, aggregate
#liquid/illiquid asset demand and effective liquidity relative to grid point 3 of inf_grid.
#NOTE(review): the axis labels say "from pi=1.4%", but with inf_grid = range(0,stop=15,length=20),
#inf_grid[3] = 2*15/19 ≈ 1.58% -- confirm the intended reference point or label.
p1=plot(inf_grid, (Rf_grid.^12.0.-1.0)*100,
    line=3,
    color=:darkblue,
    label="",
    title=L"Illiquid return, $R^\iota$",
    ylabel="Annual Rate (%)",
    xlabel="Annual Inflation (%)",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    rightmargin=(50,:px),
    dpi=400)
p2=plot(inf_grid, ((py_grid./py_grid[3]).-1.0)*100,
    line=3,
    color=:darkgreen,
    label="",
    title=L"Price of Early Consumption, $p^y$",
    ylabel=L"% change from  $\pi=1.4$%",
    xlabel="Annual Inflation (%)",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    rightmargin=(50,:px),
    dpi=400)
p3=plot(inf_grid, ((MD_grid./MD_grid[3]).-1.0)*100,
    line=3,
    color=:darkred,
    label="Liquid",
    title="Aggregate Asset Demands",
    ylabel=L"% change from  $\pi=1.4$%",
    xlabel="Annual Inflation (%)",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    legend=:bottomleft,
    dpi=400)
plot!(inf_grid, ((Jd_grid./Jd_grid[3]).-1.0)*100,
    line=(3, :dash),
    color=:black,
    label="Illiquid")
#effective liquidity: liquid demand plus the ALPHAmf-weighted illiquid demand
eff_liq = MD_grid .+ parms_base.ALPHAmf.*Jd_grid;
plot!(inf_grid, ((eff_liq./eff_liq[3]).-1.0)*100,
    line=(3, :dashdot),
    color=:darkblue,
    label="Effective Liquidity")
#
#combine the three panels and save
plot(p1,p2,p3,layout=(1,3),size=(1400,400))
savefig("../figures/inf_prices_gwaste.pdf")

#Plot (Figure 7 - left panel)
#Long-run Phillips curves from U_cf (column 1 = both channels, 2 = py only, 3 = Rf only), with
#inflation on the vertical axis. NOTE(review): unemployment is shown relative to grid point 3 of
#inf_grid, not de-meaned as the x-label says -- confirm the label.
plot((U_cf[:,1].-U_cf[3,1])*100,inf_grid,
    line=3,
    color=:darkgreen,
    label="Model Long-run Phillips Curve",
    title="Phillips Curve",
    ylabel="Annual Inflation (%)",
    xlabel="Unemployment Rate (%, de-meaned)",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    legend=:outerright,
    size=(850,400),
    rightmargin=(60,:px),
    dpi=400)
plot!((U_cf[:,3].-U_cf[3,3])*100,inf_grid,
    line=(3, :dashdot),
    color=:red,
    label=L"Interest rate channel ($R^f$ only)")
plot!((U_cf[:,2].-U_cf[3,2])*100,inf_grid,
    line=(3.5, :dot),
    color=:black,
    label=L"Aggregate demand channel ($p^y$ only)")
#
savefig("../figures/lr_phillips_curve_gwaste.pdf")

###############################################################################
###############################################################################
# Cross-sectional effects of inflation - Section 6.2
###############################################################################
###############################################################################
#Load the pi=0% steady state (comparative-statics point 1) and the pi≈10% steady state (point 14),
#then build early consumption, total consumption and income for both.
mm_base = load("../model_data_gwaste/compstat_1.jld2", "mm");
mmts_base = load("../model_data_gwaste/compstat_1.jld2", "mmts");
mm_new = load("../model_data_gwaste/compstat_14.jld2", "mm");

#Early-good consumption in each steady state: weights (1 - ALPHAmf) and ALPHAmf on
#min(ystar, am_p/py) and min(ystar, a_p/py).
earlyc_base, earlyc_new = map((mm_base, mm_new)) do m
    w = parms_base.ALPHAmf
    (1.0 - w) .* min.(m.ystar, m.am_p ./ m.py) .+ w .* min.(m.ystar, m.a_p ./ m.py)
end;

#Total consumption: late consumption c plus ALPHA-weighted early-good spending py * earlyc.
cons_base = mm_base.c .+ parms_base.ALPHA * mm_base.py .* earlyc_base;
cons_new = mm_new.c .+ parms_base.ALPHA * mm_new.py .* earlyc_new;

#Income: consumption plus next-period liquid (am_p/Rm) and illiquid ((a_p - am_p)/Rf) holdings.
inc_base = cons_base .+ mm_base.am_p ./ mm_base.Rm .+ (mm_base.a_p .- mm_base.am_p) ./ mm_base.Rf;
inc_new = cons_new .+ mm_new.am_p ./ mm_new.Rm .+ (mm_new.a_p .- mm_new.am_p) ./ mm_new.Rf;

#Percentile indices: for each cutoff p = 0.1, ..., 1.0, the last grid point whose baseline value of
#mmts_base.G (presumably the cumulative wealth distribution, sorted ascending) is <= p; cutoff 0 maps to grid point 1.
pctiles=range(0,1,step=.1);
pctind = vcat(1, [searchsortedlast(mmts_base.G, pctiles[k]) for k in 2:length(pctiles)])

#Change in liquid savings rate
#Liquid saving (am_p/Rm) as a share of inc, averaged over productivity with gz weights separately for
#employed (index 1) and unemployed (index 2) households; changes in percentage points between the
#pi≈10% (mm_new) and pi=0 (mm_base) steady states, evaluated at the percentile grid points pctind.
lsr_base = ((mm_base.am_p./mm_base.Rm))./inc_base;
lsr_new = ((mm_new.am_p./mm_new.Rm))./inc_new;
lsr_base_1 = lsr_base[:,1,1].*mm_base.gz[1] .+  lsr_base[:,1,2].*mm_base.gz[2] .+  lsr_base[:,1,3].*mm_base.gz[3];
lsr_base_0 = lsr_base[:,2,1].*mm_base.gz[1] .+  lsr_base[:,2,2].*mm_base.gz[2] .+  lsr_base[:,2,3].*mm_base.gz[3];
lsr_new_1 = lsr_new[:,1,1].*mm_new.gz[1] .+  lsr_new[:,1,2].*mm_new.gz[2] .+  lsr_new[:,1,3].*mm_new.gz[3];
lsr_new_0 = lsr_new[:,2,1].*mm_new.gz[1] .+  lsr_new[:,2,2].*mm_new.gz[2] .+  lsr_new[:,2,3].*mm_new.gz[3];

lsrchg_1 = (lsr_new_1[pctind].-lsr_base_1[pctind])*100;
#NOTE(review): `lsrchg_0[3]-=0.65` is a hard-coded manual shift of one plotted model point (unemployed,
#3rd entry), offset by the +0.65 applied to isrchg_0[3]; document the justification or remove it.
lsrchg_0 = (lsr_new_0[pctind].-lsr_base_0[pctind])*100; lsrchg_0[3]-=0.65

p1=plot(pctiles*100,lsrchg_1,
    lw=5, label="",
    color=:darkblue, line=:solid,
    title="Change in Liquid Savings Rate",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% point change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:topright,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(150,:px),
    bottommargin=(400,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,lsrchg_0,
    lw=4, label="", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Change in illiquid savings rate
#Illiquid saving ((a_p - am_p)/Rf) as a share of inc, averaged over productivity with gz weights
#separately for employed (index 1) and unemployed (index 2) households; changes in percentage points
#between the pi≈10% (mm_new) and pi=0 (mm_base) steady states at the percentile grid points pctind.
isr_base = (((mm_base.a_p.-mm_base.am_p)./mm_base.Rf))./inc_base;
isr_new = (((mm_new.a_p.-mm_new.am_p)./mm_new.Rf))./inc_new;
isr_base_1 = isr_base[:,1,1].*mm_base.gz[1] .+  isr_base[:,1,2].*mm_base.gz[2] .+  isr_base[:,1,3].*mm_base.gz[3];
isr_base_0 = isr_base[:,2,1].*mm_base.gz[1] .+  isr_base[:,2,2].*mm_base.gz[2] .+  isr_base[:,2,3].*mm_base.gz[3];
isr_new_1 = isr_new[:,1,1].*mm_new.gz[1] .+  isr_new[:,1,2].*mm_new.gz[2] .+  isr_new[:,1,3].*mm_new.gz[3];
isr_new_0 = isr_new[:,2,1].*mm_new.gz[1] .+  isr_new[:,2,2].*mm_new.gz[2] .+  isr_new[:,2,3].*mm_new.gz[3];

isrchg_1 = (isr_new_1[pctind].-isr_base_1[pctind])*100;
isrchg_0 = (isr_new_0[pctind].-isr_base_0[pctind])*100;
#NOTE(review): hard-coded manual shift of one plotted model point (unemployed, 3rd entry), offsetting
#the -0.65 applied to lsrchg_0[3]; document the justification or remove it.
isrchg_0[3]+=0.65;

p2=plot(pctiles*100,isrchg_1,
    lw=5, label="",
    color=:darkblue, line=:solid,
    title="Change in Illiquid Savings Rate",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% point change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:topright,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(150,:px),
    bottommargin=(150,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,isrchg_0,
    lw=4, label="", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Change in consumption rate
#Consumption as a share of inc, averaged over productivity with gz weights for employed (index 1) and
#unemployed (index 2) households; changes in percentage points at the percentile grid points pctind.
cr_base = cons_base./inc_base;
cr_new = cons_new./inc_new;
cr_base_1 = cr_base[:,1,1].*mm_base.gz[1] .+  cr_base[:,1,2].*mm_base.gz[2] .+  cr_base[:,1,3].*mm_base.gz[3];
cr_base_0 = cr_base[:,2,1].*mm_base.gz[1] .+  cr_base[:,2,2].*mm_base.gz[2] .+  cr_base[:,2,3].*mm_base.gz[3];
cr_new_1 = cr_new[:,1,1].*mm_new.gz[1] .+  cr_new[:,1,2].*mm_new.gz[2] .+  cr_new[:,1,3].*mm_new.gz[3];
cr_new_0 = cr_new[:,2,1].*mm_new.gz[1] .+  cr_new[:,2,2].*mm_new.gz[2] .+  cr_new[:,2,3].*mm_new.gz[3];

crchg_1 = (cr_new_1[pctind].-cr_base_1[pctind])*100;
crchg_0 = (cr_new_0[pctind].-cr_base_0[pctind])*100;

p3=plot(pctiles*100,crchg_1,
    lw=5, label="Employed",
    color=:darkblue, line=:solid,
    title="Change in Consumption Rate",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% point change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:topright,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(50,:px),
    bottommargin=(150,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,crchg_0,
    lw=4, label="Unemployed", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Change in early consumption
#Percent change in early-good consumption at each state, evaluated at the percentile grid points pctind and
#then averaged over productivity with the baseline weights gz (employed = index 1, unemployed = index 2).
earlyc_chg = ((earlyc_new./earlyc_base).-1.0)*100;
earlyc_chg_1 = earlyc_chg[pctind,1,1].*mm_base.gz[1] .+  earlyc_chg[pctind,1,2].*mm_base.gz[2] .+  earlyc_chg[pctind,1,3].*mm_base.gz[3];
earlyc_chg_0 = earlyc_chg[pctind,2,1].*mm_base.gz[1] .+  earlyc_chg[pctind,2,2].*mm_base.gz[2] .+  earlyc_chg[pctind,2,3].*mm_base.gz[3];
#NOTE(review): hard-coded manual shift of one plotted model point (unemployed, 3rd entry); document or remove.
    earlyc_chg_0[3]-=4.0;
p4=plot(pctiles*100,earlyc_chg_1,
    lw=5, label="",
    color=:darkblue, line=:solid,
    title="% Change Early Consumption",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:topright,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(150,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,earlyc_chg_0,
    lw=4, label="", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Change in late consumption
#Percent change (not percentage points) in late consumption c at the percentile grid points pctind,
#averaged over productivity with the baseline weights gz (employed = index 1, unemployed = index 2).
lc_chg = ((mm_new.c./mm_base.c).-1.0)*100;
lc_chg_1 = lc_chg[pctind,1,1].*mm_base.gz[1] .+  lc_chg[pctind,1,2].*mm_base.gz[2] .+  lc_chg[pctind,1,3].*mm_base.gz[3];
lc_chg_0 = lc_chg[pctind,2,1].*mm_base.gz[1] .+  lc_chg[pctind,2,2].*mm_base.gz[2] .+  lc_chg[pctind,2,3].*mm_base.gz[3];
p5=plot(pctiles*100,lc_chg_1,
    lw=5, label="",
    color=:darkblue, line=:solid,
    title="% Change Late Consumption",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:topleft,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(50,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,lc_chg_0,
    lw=4, label="", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Change in total consumption
#Percent change (not percentage points) in total consumption at the percentile grid points pctind,
#averaged over productivity with the baseline weights gz (employed = index 1, unemployed = index 2).
c_chg = ((cons_new./cons_base).-1.0)*100;
c_chg_1 = c_chg[pctind,1,1].*mm_base.gz[1] .+  c_chg[pctind,1,2].*mm_base.gz[2] .+  c_chg[pctind,1,3].*mm_base.gz[3];
c_chg_0 = c_chg[pctind,2,1].*mm_base.gz[1] .+  c_chg[pctind,2,2].*mm_base.gz[2] .+  c_chg[pctind,2,3].*mm_base.gz[3];
p6=plot(pctiles*100,c_chg_1,
    lw=5, label="Employed",
    color=:darkblue, line=:solid,
    title="% Change Total Consumption",
    xlabel=L"Total Wealth Percentile ($\pi=0$)",
    ylabel="% change",
    tickfont=font(16,"serif"),
    guidefont=font(16,"serif"),
    legendfont=font(16,"serif"),
    titlefont=font(16,"serif"),
    legend=:bottomleft,
    markershape=:circle,
    markersize=12,
    markercolor=:darkblue,
    rightmargin=(50,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
plot!(pctiles*100,c_chg_0,
    lw=4, label="Unemployed", markershape=:xcross,
    markersize=12,
    markercolor=:darkred,
    color=:darkred, line=:dot)
#

#Plot Figure 8
#Saved with the "_gwaste" suffix used by every other figure from this model, so it cannot collide with
#a same-named figure written to the shared ../figures folder by another version of the model.
plot(p1,p2,p3,p4,p5,p6,layout=(2,3),size=(1500,900))
savefig("../figures/chg_rates_by_wealth_emp_gwaste.pdf")

###############################################################################
###############################################################################
# Inflation and Welfare - aggregate - Section 7.1
###############################################################################
###############################################################################

###############################################################################
# Compute transitional dynamics and total welfare cost - Figure 9 (left panel)
###############################################################################
T=500; #length of transition
dx=1e-4; #step size of the finite-difference derivative approximation

#baseline steady state at inflation 0%
parms_base=load("../model_data_gwaste/base_calib.jld2","parms_base");
md=load("../model_data_gwaste/base_calib.jld2","md");
mm_base=load(string("../model_data_gwaste/compstat_",1,".jld2"),"mm");

#Compute undistorted value function, W, in base steady state (used as initial guess in Wel_diff)
Wbase0_Delt0 = W_solve_Delt(zeros((Na,2,Nz)),1.0;mm=mm_base, parms=parms_base);

#Compute transition with constant prices in baseline ss (more numerically accurate for welfare computation).
#Backward recursion from the steady-state W at T+1, holding prices and policies at their mm_base values.
#The wealth grid is taken from mm_base (this section loads mm_base, not the global mm), so the section
#also runs on its own.
Wt_paths0 = zeros((T+1,Na,2,Nz));
Wt_paths0[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_base,parms=parms_base);
for t=T:-1:1
    Wt_paths0[t,:,:,:]=W_iter(Wt_paths0[t+1,:,:,:]; py_prime=mm_base.py,tau=mm_base.tau[1,1,1],wages_bar=mm_base.wages_bar,theta_prime=mm_base.theta,am_p=mm_base.am_p,a_p=mm_base.a_p,ystar=mm_base.ystar,c=mm_base.c,agrid=mm_base.agrid,parms=parms_base)
end

#Preallocate containers for the transitional dynamics and welfare results.
#Row i of every container indexed by N corresponds to inflation rate inf_grid[i].
#NOTE(review): Cequiv_bystate_t0 is not written anywhere in the code shown here, so it is saved as zeros.
Cequiv_avg_t0 = zeros(N);                 # consumption equivalent measured at time 0
Cequiv_bystate_t0 = zeros(Na, 2, Nz);
Cequiv_avg_ss = zeros(N);                 # consumption equivalent between steady states

# scalar aggregates along the transition: (inflation rate, period)
Rmt_paths = zeros(N, T);
inft_paths = zeros(N, T);
Rft_paths = zeros(N, T);
pyt_paths = zeros(N, T);
Mt_paths = zeros(N, T);
Agt_paths = zeros(N, T);

# objects with extra state dimensions
taut_paths = zeros(N, 2, T);
wagest_paths = zeros(N, 2, Nz, T);

phimt_paths = zeros(N, T);
Yst_paths = zeros(N, T);
thetat_paths = zeros(N, T);
phift_paths = zeros(N, Nz, T);
ut_paths = zeros(N, T);
Ydt_paths = zeros(N, T);
Aft_paths = zeros(N, T);
Amt_paths = zeros(N, T);
Ct_paths = zeros(N, T);
Ct1_paths = zeros(N, T);
Ct0_paths = zeros(N, T);
Aft1_paths = zeros(N, T);
Aft0_paths = zeros(N, T);
Amt1_paths = zeros(N, T);
Amt0_paths = zeros(N, T);

# value function along each transition, periods 1..T plus the terminal steady state at T+1
Wt_paths = zeros(N, T + 1, Na, 2, Nz);

#Compute welfare cost for all points in inf_grid (or Rm_grid)
#For each non-baseline inflation rate i: solve the unanticipated transition from mm_base (pi=0) to the new
#steady state, compute lifetime utility along the path by backward recursion, and convert it into consumption
#equivalents relative to the baseline. Paths and welfare results are checkpointed to disk on every iteration.
for i=2:N
    #load new steady state at inflation Rm_grid[i]
    parms_new=load(string("../model_data_gwaste/compstat_",i,".jld2"),"parms");
    mm_new=load(string("../model_data_gwaste/compstat_",i,".jld2"),"mm");

    #Define exog paths of Z=(Mt, Agt), normalizing WLOG Mt[0]=1, and initialize endogenous paths of U=(Rmt,Rft,pyt)
    Mt = ones(T); 
    Mt[1] = (1.0/mm_new.Rm); 
    for j=2:T
        Mt[j] = Mt[j-1]*(1.0/mm_new.Rm)
    end
    Agt = ones(T).*mm_new.Ag;
    Z = zeros((T,2));
    Z[:,1] = copy(Mt);
    Z[:,2] = copy(Agt);
    U0 = zeros((T,3));
    U0[:,1] = ones(T)*mm_new.Rm;
    U0[:,2] = ones(T)*mm_new.Rf;
    U0[:,3] = ones(T)*mm_new.py;
    Ussnew=copy(U0);#initial guess U0 is the same as the terminal steady state
    Zssnew=copy(Z);

    #Compute transitional dynamics using SSJ method
    Rmt_paths[i,:], Rft_paths[i,:], pyt_paths[i,:], Mt_paths[i,:], Agt_paths[i,:], taut, wagest_paths[i,:,:,:], phimt_paths[i,:], Yst_paths[i,:], thetat_paths[i,:], phift, ut_paths[i,:], Ydt_paths[i,:], Aft_paths[i,:], Amt_paths[i,:], Ct_paths[i,:], Ct1_paths[i,:], Ct0_paths[i,:], Aft1_paths[i,:], Aft0_paths[i,:], Amt1_paths[i,:], Amt0_paths[i,:], a_p_mat, am_p_mat, c_mat, ystar_mat, gt, H1t, H2t, H3t = NLPFD(Z,U0,Zssnew,Ussnew;T,dx,mm_base=mm_base,mm_new=mm_new,parms=parms_base,md=md,tol=1e-4,max_iter=2);
    taut_paths[i,:,:,:]=copy(taut');
    phift_paths[i,:,:]=copy(phift');
    inft_paths[i,:] = ((1.0./Rmt_paths[i,:].^12).-1.0)*100;
    
    #checkpoint all paths computed so far
    save("../model_data_gwaste/time_paths_unanticipated.jld2",
        "Rmt_paths",Rmt_paths,
        "inft_paths", inft_paths,
        "Rft_paths", Rft_paths,
        "pyt_paths", pyt_paths,
        "Mt_paths", Mt_paths,
        "Agt_paths", Agt_paths,
        "taut_paths", taut_paths,
        "wagest_paths", wagest_paths,
        "phimt_paths", phimt_paths,
        "Yst_paths", Yst_paths,
        "thetat_paths", thetat_paths,
        "phift_paths", phift_paths,
        "ut_paths", ut_paths,
        "Ydt_paths", Ydt_paths,
        "Aft_paths", Aft_paths,
        "Amt_paths", Amt_paths,
        "Ct_paths", Ct_paths,
        "Ct1_paths", Ct1_paths,
        "Ct0_paths", Ct0_paths,
        "Aft1_paths", Aft1_paths,
        "Aft0_paths", Aft0_paths,
        "Amt1_paths", Amt1_paths,
        "Amt0_paths", Amt0_paths);
    #

    #The paths are used directly from memory below. Re-binding globals such as Rmt_paths inside this top-level
    #loop (e.g. by re-loading them from the file just saved) would make them loop-local when the file is run
    #non-interactively (Julia's soft-scope rule) and break the indexed assignments above.
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_paths[i,:],Rft_paths[i,:],Rmt_paths[i,:],taut_paths[i,:,:]',wagest_paths[i,:,:,:],thetat_paths[i,:];parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute steady state welfare in new steady state
    Wt_paths[i,T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Compute utility along transition (backward from the terminal steady state; wealth grid from mm_base)
    for t=T:-1:1
        Wt_paths[i,t,:,:,:]=W_iter(Wt_paths[i,t+1,:,:,:]; py_prime=pyt_paths[i,t],tau=taut_paths[i,1,t],wages_bar=wagest_paths[i,:,:,t],theta_prime=thetat_paths[i,t],am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_base.agrid,parms=parms_base)
    end
    #distribution-weighted average utility in each period (not used further in the code shown)
    Wavg=zeros(T+1);
    for t=1:T+1
        Wavg[t] = sum(gt[t,:,:,:].*Wt_paths[i,t,:,:,:])
    end

    #Compute welfare (avg. lifetime expected utility in mm_base and mm_new and consumption equivalent W_base(cΔ)=W_new(c')) from time 0
    Cequiv_avg_t0[i] = Welfare_avg_CE(Wt_paths[i,1,:,:,:],gt[1,:,:,:],Wbase0_Delt0;mm_base,parms_base)

    #Compute welfare (avg. lifetime expected utility in mm_base and mm_new and consumption equivalent W_base(cΔ)=W_new(c')) between steady states
    Cequiv_avg_ss[i] = Welfare_avg_CE(Wt_paths[i,T+1,:,:,:],mm_new.g, Wbase0_Delt0; mm_base,parms_base)

    #checkpoint welfare results
    save("../model_data_gwaste/time_paths_unanticipated_welfare.jld2",
    "Wt_paths",Wt_paths,
    "Cequiv_avg_t0", Cequiv_avg_t0,
    "Cequiv_avg_ss", Cequiv_avg_ss,
    "Cequiv_bystate_t0", Cequiv_bystate_t0);
end

#Plot Figure 9 (left panel): welfare cost of inflation measured at time 0 (including the transition)
#versus the comparison across steady states
welfare_file = "../model_data_gwaste/time_paths_unanticipated_welfare.jld2";
dat = load(welfare_file);
Cequiv_avg_t0 = dat["Cequiv_avg_t0"];
Cequiv_avg_ss = dat["Cequiv_avg_ss"];
#entry 1 is the pi=0% baseline (the welfare loops start at i=2), so its consumption equivalent is pinned to one
Cequiv_avg_t0[1] = 1.0;
Cequiv_avg_ss[1] = 1.0;
#welfare cost in percent of consumption: (1 - consumption equivalent) x 100
welcost_t0 = (1.0 .- Cequiv_avg_t0) .* 100;
welcost_ss = (1.0 .- Cequiv_avg_ss) .* 100;
plot(inf_grid, welcost_t0,
    lw=3, label="Time 0",
    color=:black, line=:solid,
    title="Welfare cost of inflation",
    xlabel="Annual Inflation Rate (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(10,"serif"),
    guidefont=font(12,"serif"),
    legendfont=font(10,"serif"),
    fontfamily="serif",
    legend=:topleft,
    dpi=400)
hline!([0.0], lw=1, color=:black, label="")
plot!(inf_grid, welcost_ss,
    lw=3, label="Steady state", line=:dash,
    color=:darkblue)
savefig("../figures/welfare_cost_t0_vs_ss_gwaste.pdf")

###############################################################################
# Compute decomposition of welfare cost - Figure 9 (right panel)
###############################################################################
#load baseline (pi=0%) eq values and parameters
parms_base=load("../model_data_gwaste/base_calib.jld2","parms_base");
md=load("../model_data_gwaste/base_calib.jld2","md");
mm_base=load(string("../model_data_gwaste/compstat_",1,".jld2"),"mm");

#initialize consumption equivalents by money growth rate; entry 1 is the pi=0% baseline, pinned to one
Cequiv_avg_t0_tau = zeros(N);
Cequiv_avg_t0_wages = zeros(N);
Cequiv_avg_t0_Rm = zeros(N);
Cequiv_avg_t0_Rf = zeros(N);
Cequiv_avg_t0_py = zeros(N);
Cequiv_avg_t0_theta = zeros(N);
Cequiv_avg_t0_tau[1] = 1.0;
Cequiv_avg_t0_wages[1] = 1.0;
Cequiv_avg_t0_Rm[1] = 1.0;
Cequiv_avg_t0_Rf[1] = 1.0;
Cequiv_avg_t0_py[1] = 1.0;
Cequiv_avg_t0_theta[1] = 1.0;
Wt_paths = zeros((T+1,Na,2,Nz)); #scratch array, re-used for every counter-factual in the loop below

#Counter-factual time paths with every equilibrium object held at its baseline (pi=0%) steady-state value.
#They must exist before the baseline transition is computed with ha_block below.
Rmt_base = ones(T).*mm_base.Rm;
Rft_base = ones(T).*mm_base.Rf;
pyt_base = ones(T).*mm_base.py;
taut_base = ones((T,2)).*mm_base.tau[1,1,1];
wagest_base = ones((2,Nz,T)); wagest_base.=mm_base.wages_bar;
thetat_base = ones(T).*mm_base.theta;

#Solve for initial utility under constant prices
Wt_paths_base = zeros((T+1,Na,2,Nz));
mm_new = deepcopy(mm_base);
Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths_base[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base); #compute steady state W(a)

#Update transition within HH block
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, backwards from the terminal value stored in Wt_paths_base[T+1].
#The recursion must read and write Wt_paths_base: Wt_paths_base[1,:,:,:] is the baseline handed to
#Welfare_avg_CE in the decomposition loop, while Wt_paths is still an all-zero scratch array here.
for t=T:-1:1
    Wt_paths_base[t,:,:,:]=W_iter(Wt_paths_base[t+1,:,:,:]; py_prime=mm_new.py,tau=taut_base[t,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
end
#Store the period-1 policies of the constant-price transition in mm_base
mm_base.am_p = am_p_mat[1,:,:,:];
mm_base.a_p = a_p_mat[1,:,:,:];
mm_base.ystar = ystar_mat[1,:,:,:];
mm_base.c = c_mat[1,:,:,:];

#Loop over money growth (and terminal inflation) rates.
#The file of transition paths does not depend on i, so it is read once here instead of on every iteration.
dat=load("../model_data_gwaste/time_paths_unanticipated.jld2");
for i=2:N
    #extract the transition paths towards the i-th terminal steady state (slices are copies of the stored arrays)
    Rmt = dat["Rmt_paths"][i,:];
    Rft = dat["Rft_paths"][i,:];
    pyt = dat["pyt_paths"][i,:];
    taut = dat["taut_paths"][i,:,:];
    taut=taut';
    wagest = dat["wagest_paths"][i,:,:,:];
    thetat = dat["thetat_paths"][i,:];

    #counter-factuals are for constant prices and eq objects at initial ss
    #(rebuilt every iteration so each channel below starts from fresh baseline paths)
    Rmt_base = ones(T).*mm_base.Rm;
    Rft_base = ones(T).*mm_base.Rf;
    pyt_base = ones(T).*mm_base.py;
    taut_base = ones((T,2)).*mm_base.tau[1,1,1];
    wagest_base = ones((2,Nz,T)); wagest_base.=mm_base.wages_bar;
    thetat_base = ones(T).*mm_base.theta;

    ##########################
    # transfers only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.tau.=taut[end,1]; 
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base); #compute steady state W(a)

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=taut[t,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_tau[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base);

    ##########################
    # wages only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.wages_bar.=wagest[:,:,end]; 
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut_base,wagest,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=wagest[:,:,t],theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_wages[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base);

    ##########################
    # Rm only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.Rm=Rmt[end];
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies, Wa, and g
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_Rm[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base);

    ##########################
    # Rf only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.Rf=Rft[end];
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft,Rmt_base,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_Rf[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base)

    ##########################
    # py only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.py=pyt[end];
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt,Rft_base,Rmt_base,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=pyt[t],tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_py[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base)

    ##########################
    # theta only
    ##########################
    mm_new = deepcopy(mm_base);
    mm_new.theta=thetat[end];
    Wa_solve!(mm_new;tol=1e-5,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
    Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

    #Update transition within HH block
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut_base,wagest_base,thetat;parms=parms_base,mm=mm_base,mm_new=mm_new);

    #Compute transition in W 
    for t=T:-1:1
        Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=thetat[t],am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
    end

    #Compute welfare cost at time=0
    Cequiv_avg_t0_theta[i] = Welfare_avg_CE(Wt_paths[1,:,:,:],gt[1,:,:,:],Wt_paths_base[1,:,:,:];mm_base,parms_base)

    #save after every money growth rate so partial results survive an interrupted run
    save("../model_data_gwaste/time_paths_unanticipated_welfare_decomposition.jld2",
        "Cequiv_avg_t0_tau", Cequiv_avg_t0_tau,
        "Cequiv_avg_t0_wages", Cequiv_avg_t0_wages,
        "Cequiv_avg_t0_Rm", Cequiv_avg_t0_Rm,
        "Cequiv_avg_t0_Rf", Cequiv_avg_t0_Rf,
        "Cequiv_avg_t0_py", Cequiv_avg_t0_py,
        "Cequiv_avg_t0_theta", Cequiv_avg_t0_theta)
end

#plot Figure 9 (right panel): total time-0 welfare cost and its channel-by-channel decomposition
Cequiv_avg_t0 = load("../model_data_gwaste/time_paths_unanticipated_welfare.jld2", "Cequiv_avg_t0");
dat = load("../model_data_gwaste/time_paths_unanticipated_welfare_decomposition.jld2");
(Cequiv_avg_t0_tau, Cequiv_avg_t0_wages, Cequiv_avg_t0_Rm,
 Cequiv_avg_t0_Rf, Cequiv_avg_t0_py, Cequiv_avg_t0_theta) = map(key -> dat[key],
    ("Cequiv_avg_t0_tau", "Cequiv_avg_t0_wages", "Cequiv_avg_t0_Rm",
     "Cequiv_avg_t0_Rf", "Cequiv_avg_t0_py", "Cequiv_avg_t0_theta"));

#total welfare cost (thick black line)
plot(inf_grid,(1.0.-Cequiv_avg_t0).*100,
    lw=5, label="Total",
    color=:black, line=:solid,
    title="Welfare cost of inflation",
    xlabel="Annual Inflation Rate (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(8,"serif"),
    guidefont=font(10,"serif"),
    legendfont=font(8,"serif"),
    fontfamily="serif",
    legend=:topleft,
    ylim=[-2,3.5],
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")

#one series per channel, drawn in legend order: (consumption equivalents, label, marker, color)
decomp_series = (
    (Cequiv_avg_t0_Rm,    L"Direct Effect, $R^m$",              :circle,  :darkblue),
    (Cequiv_avg_t0_py,    L"Price Effect, $p$",                 :rect,    :darkred),
    (Cequiv_avg_t0_Rf,    L"Interest Rate Effect, $R^\iota$",   :diamond, :darkgreen),
    (Cequiv_avg_t0_wages, L"Wage Effect, $w^{e,j}$",            :cross,   :darkviolet),
    (Cequiv_avg_t0_theta, L"Emp. Risk Effect, $\theta$",        :xcross,  :darkorange),
    (Cequiv_avg_t0_tau,   L"Transfer Effect, $\tau$",           :octagon, :darkgray),
);
for (ce, lab, shape, col) in decomp_series
    plot!(inf_grid,(1.0.-ce).*100,
        lw=3, label=lab, markershape=shape, ms=6,
        color=col)
end
savefig("../figures/welfare_cost_decomposition_gwaste.pdf")


###############################################################################
###############################################################################
# Inflation and welfare in cross-section - Section 7.2
###############################################################################
###############################################################################

#initial steady state at pi=0%
parms_base=load("../model_data_gwaste/base_calib.jld2","parms_base");
mm_base=load(string("../model_data_gwaste/compstat_",1,".jld2"),"mm");
md_base=load(string("../model_data_gwaste/compstat_",1,".jld2"),"md");
mmts_base=load(string("../model_data_gwaste/compstat_",1,".jld2"),"mmts");

#percentiles of total wealth (0%,10%,...,100%) at which the cross-section is evaluated
pctiles=range(0,1,step=.1);
Npct=length(pctiles);

#initialize (wealth percentile x employment status x income state)
Cequiv_avg_t0_tau_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_wages_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_Rm_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_Rf_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_py_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_theta_cs = zeros((Npct,2,Nz));
Cequiv_avg_t0_total_cs = zeros((Npct,2,Nz));
Wt_paths = zeros((T+1,Na,2,Nz)); 

#Compute undistorted value function in steady state (used as initial guess in Wel_diff)
Wbase0_Delt0 = W_solve_Delt(zeros((Na,2,Nz)),1.0; mm=mm_base, parms=parms_base);

#Map each percentile to an asset-grid index in the baseline (pi=0%) steady state: the last grid point
#whose value of mmts_base.G (presumably the CDF of total wealth) is <= the percentile. The 0th percentile
#is the first grid point. The max(...,1) clamp keeps the index valid when G at the first grid point
#already exceeds a percentile, where searchsortedlast would return the invalid index 0.
pctind = [1; [max(searchsortedlast(mmts_base.G,p),1) for p in pctiles[2:end]]];

#load transitional dynamics from initial steady state (pi=0%, index 1) to terminal steady state (pi=10%, index iterm)
iterm = 14;
dat=load("../model_data_gwaste/time_paths_unanticipated.jld2");
Rmt = dat["Rmt_paths"][iterm,:];
Rft = dat["Rft_paths"][iterm,:];
pyt = dat["pyt_paths"][iterm,:];
taut = dat["taut_paths"][iterm,:,:];
taut=taut';
wagest = dat["wagest_paths"][iterm,:,:,:];
thetat = dat["thetat_paths"][iterm,:];

#define counter-factual time series (at initial steady-state values)
Rmt_base = ones(T).*mm_base.Rm;
Rft_base = ones(T).*mm_base.Rf;
pyt_base = ones(T).*mm_base.py;
taut_base = ones((T,2)).*mm_base.tau[1,1,1];
wagest_base = ones((2,Nz,T)); wagest_base.=mm_base.wages_bar;
thetat_base = ones(T).*mm_base.theta;


###############################################################################
# Total Welfare Cost in cross-section - Figure 11 (left panel)
###############################################################################
#Resolve for terminal utility at new steady state: every equilibrium object at its end-of-path value
#(indexing with `end` keeps this tied to the path length T, as in the aggregate decomposition loop)
mm_new = deepcopy(mm_base);
mm_new.tau.=taut[end,1]; 
mm_new.wages_bar.=wagest[:,:,end];
mm_new.Rm=Rmt[end];
mm_new.Rf=Rft[end];
mm_new.py=pyt[end];
mm_new.theta=thetat[end];
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base); #compute steady state W(a)

#Update transition within HH block
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt,Rft,Rmt,taut,wagest,thetat;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, backwards from the terminal steady state
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=pyt[t],tau=taut[t,1],wages_bar=wagest[:,:,t],theta_prime=thetat[t],am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_total_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end

#Save output
save("../model_data_gwaste/time_paths_unanticipated_welfare_cross_section.jld2",
    "Cequiv_avg_t0_total_cs", Cequiv_avg_t0_total_cs)

#Plot Figure 11 (left panel): total welfare cost by wealth percentile, income state and employment status
dat=load("../model_data_gwaste/time_paths_unanticipated_welfare_cross_section.jld2")
Cequiv_avg_t0_total_cs = dat["Cequiv_avg_t0_total_cs"];
#color by income state k (matching the annotations: 1=low, 2=middle, 3=high income),
#line style by employment status l (1=solid/employed, 2=dash/unemployed)
income_colors = (:darkblue, :darkgreen, :black);
emp_lines = (:solid, :dash);
p1=plot(pctiles*100,(1.0.-Cequiv_avg_t0_total_cs[:,1,1]).*100,
    lw=3, label="",
    color=:darkblue, line=:solid,
    title="Welfare cost of 10% inflation",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(13,"serif"),
    fontfamily="serif",
    legend=:topleft,
    markershape=:circle,
    markersize=6,
    markercolor=:darkblue,
    dpi=400)
#remaining series in the order (l,k) = (2,1),(1,2),(2,2),(1,3),(2,3); (1,1) is drawn above
for k=1:3, l=1:2
    (k,l) == (1,1) && continue
    plot!(pctiles*100,(1.0.-Cequiv_avg_t0_total_cs[:,l,k]).*100,
        lw=3, label="",markershape=:circle,
        markersize=6,
        markercolor=income_colors[k],
        color=income_colors[k], line=emp_lines[l])
end
annotate!(38,3.9,text("high-income",12,:black,:serif))
annotate!(38,2.2,text("middle-income",12,:darkgreen,:serif))
annotate!(38,1.2,text("low-income",12,:darkblue,:serif))
annotate!(50,5.05,text("(solid=employed, dash=unemployed)",10,:black,:serif))
savefig("../figures/welfare_cost_by_wealth_and_income_gwaste.pdf")

###############################################################################
# Welfare Cost in cross-section - Transfers (tau) only - Figure 12 (bottom-left panel)
###############################################################################

#Solve for counter-factual terminal steady state utility: only tau moves to its end-of-path value
mm_new = deepcopy(mm_base);
mm_new.tau.=taut[end,1]; 
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base); #compute steady state W(a)

#Update transition within HH block (tau follows its path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=taut[t,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_tau_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
            print("\nWelcost=",round((1.0.-Cequiv_avg_t0_tau_cs[j,l,k])*100,digits=4), "%, for (j,l,k)=(",j,",",l,",",k,")")
        end
    end
end

###############################################################################
# Welfare Cost in cross-section - wages only - Figure 12 (bottom-right panel)
###############################################################################

#Solve for counter-factual terminal steady state utility: only wages move to their end-of-path values
mm_new = deepcopy(mm_base);
mm_new.wages_bar.=wagest[:,:,end]; 
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

#Update transition within HH block (wages follow their path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut_base,wagest,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=wagest[:,:,t],theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_wages_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end

###############################################################################
# Welfare Cost in cross-section - Rm only - Figure 12 (top-left panel)
###############################################################################

#Solve for counter-factual terminal steady state utility: only Rm moves to its end-of-path value
mm_new = deepcopy(mm_base);
mm_new.Rm=Rmt[end];
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies, Wa, and g
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

#Update transition within HH block (Rm follows its path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base);
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_Rm_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end


###############################################################################
# Welfare Cost in cross-section - Rf only - Figure 12 (top-right panel)
###############################################################################

#Solve for counter-factual terminal steady-state utility: only Rf moves to its end-of-path value
mm_new = deepcopy(mm_base);
mm_new.Rf=Rft[end];
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

#Update transition within HH block (Rf follows its path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft,Rmt_base,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_Rf_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end

###############################################################################
# Welfare Cost in cross-section - py only - Figure 12 (top-middle panel)
###############################################################################

#Compute welfare in ss under new py (only py moves to its end-of-path value)
mm_new = deepcopy(mm_base);
mm_new.py=pyt[end];
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

#Update transition within HH block (py follows its path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt,Rft_base,Rmt_base,taut_base,wagest_base,thetat_base;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=pyt[t],tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=mm_new.theta,am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_py_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end

###############################################################################
# Welfare Cost in cross-section - theta only - Figure 12 (bottom-middle panel)
###############################################################################

#Compute welfare in ss under new theta (only theta moves to its end-of-path value)
mm_new = deepcopy(mm_base);
mm_new.theta=thetat[end];
Wa_solve!(mm_new;tol=1e-10,py=mm_new.py,Rf=mm_new.Rf,Wa_p=mm_new.Wa_p,agrid=mm_new.agrid,parms=parms_base); #update steady state policies and Wa
Wt_paths[T+1,:,:,:] = W_solve_ss(zeros((Na,2,Nz));mm=mm_new,parms=parms_base);

#Update transition within HH block (theta follows its path, all other objects held at baseline)
Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_block(pyt_base,Rft_base,Rmt_base,taut_base,wagest_base,thetat;parms=parms_base,mm=mm_base,mm_new=mm_new);

#Compute transition in utility, W, to solve for time zero
for t=T:-1:1
    Wt_paths[t,:,:,:]=W_iter(Wt_paths[t+1,:,:,:]; py_prime=mm_new.py,tau=mm_new.tau[1,1,1],wages_bar=mm_new.wages_bar,theta_prime=thetat[t],am_p=am_p_mat[t,:,:,:],a_p=a_p_mat[t,:,:,:],ystar=ystar_mat[t,:,:,:],c=c_mat[t,:,:,:],agrid=mm_new.agrid,parms=parms_base)
end

#Compute welfare cost at time=0 for each wealth percentile j, employment status l and income state k
Threads.@threads for j in eachindex(pctind)
    for l=1:2
        for k=1:Nz
            Cequiv_avg_t0_theta_cs[j,l,k] = Welfare_bystate_CE(Wt_paths[1,pctind[j],l,k],Wbase0_Delt0,pctind[j],l,k;mm_base,parms_base);
        end
    end
end

#Persist the cross-sectional decomposition, then read it back so Figure 12 can be redrawn from disk alone
cs_decomp_file = "../model_data_gwaste/time_paths_unanticipated_welfare_decomposition_cross_section.jld2";
save(cs_decomp_file,
    "Cequiv_avg_t0_tau_cs", Cequiv_avg_t0_tau_cs,
    "Cequiv_avg_t0_wages_cs", Cequiv_avg_t0_wages_cs,
    "Cequiv_avg_t0_Rm_cs", Cequiv_avg_t0_Rm_cs,
    "Cequiv_avg_t0_Rf_cs", Cequiv_avg_t0_Rf_cs,
    "Cequiv_avg_t0_py_cs", Cequiv_avg_t0_py_cs,
    "Cequiv_avg_t0_theta_cs", Cequiv_avg_t0_theta_cs)

###############################################################################
# Plot Figure 12
###############################################################################
dat = load(cs_decomp_file);
(Cequiv_avg_t0_tau_cs, Cequiv_avg_t0_wages_cs, Cequiv_avg_t0_Rm_cs,
 Cequiv_avg_t0_Rf_cs, Cequiv_avg_t0_py_cs, Cequiv_avg_t0_theta_cs) = map(key -> dat[key],
    ("Cequiv_avg_t0_tau_cs", "Cequiv_avg_t0_wages_cs", "Cequiv_avg_t0_Rm_cs",
     "Cequiv_avg_t0_Rf_cs", "Cequiv_avg_t0_py_cs", "Cequiv_avg_t0_theta_cs"));

#Figure 12, R^m-only panel. Color by income state k (1=low, 2=middle, 3=high, per the annotations),
#marker by income state, line style by employment status l (1=solid/employed, 2=dash/unemployed)
income_colors = (:darkblue, :darkgreen, :black);
income_markers = (:circle, :rect, :xcross);
emp_lines = (:solid, :dash);
p1=plot(pctiles*100,(1.0.-Cequiv_avg_t0_Rm_cs[:,1,1]).*100,
    lw=3, label="",
    color=:darkblue, line=:solid,
    title=L"Welfare cost - $R^m$ only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    legend=:topleft,
    markershape=:circle,
    markersize=10,
    markercolor=:darkblue,
    bottommargin=(200,:px),
    rightmargin=(150,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
#remaining series in the order (l,k) = (2,1),(1,2),(2,2),(1,3),(2,3); (1,1) is drawn above
for k=1:3, l=1:2
    (k,l) == (1,1) && continue
    plot!(pctiles*100,(1.0.-Cequiv_avg_t0_Rm_cs[:,l,k]).*100,
        lw=3, label="",markershape=income_markers[k],
        markersize=10,
        markercolor=income_colors[k],
        color=income_colors[k], line=emp_lines[l])
end
annotate!(38,3.8,text("high-income",12,:black,:serif))
annotate!(38,2.0,text("middle-income",12,:darkgreen,:serif))
annotate!(38,0.8,text("low-income",12,:darkblue,:serif))
annotate!(50,5.2,text("(solid=employed, dash=unemployed)",12,:black,:serif))

#Figure 12, p-only panel: same color/marker/line-style encoding of (employment l, income k) as the R^m panel
income_colors = (:darkblue, :darkgreen, :black);
income_markers = (:circle, :rect, :xcross);
emp_lines = (:solid, :dash);
p2=plot(pctiles*100,(1.0.-Cequiv_avg_t0_py_cs[:,1,1]).*100,
    lw=3, label="",
    color=:darkblue, line=:solid,
    title=L"Welfare cost - $p$ only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    markershape=:circle,
    markersize=10,
    markercolor=:darkblue,
    legend=:topleft,
    bottommargin=(200,:px),
    rightmargin=(150,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
#remaining series in the order (l,k) = (2,1),(1,2),(2,2),(1,3),(2,3); (1,1) is drawn above
for k=1:3, l=1:2
    (k,l) == (1,1) && continue
    plot!(pctiles*100,(1.0.-Cequiv_avg_t0_py_cs[:,l,k]).*100,
        lw=3, label="",markershape=income_markers[k],
        markersize=10,
        markercolor=income_colors[k],
        color=income_colors[k], line=emp_lines[l])
end
#

#Figure 12, R^iota-only panel (top-right, so no right margin): same (l,k) encoding as the other panels
income_colors = (:darkblue, :darkgreen, :black);
income_markers = (:circle, :rect, :xcross);
emp_lines = (:solid, :dash);
p3=plot(pctiles*100,(1.0.-Cequiv_avg_t0_Rf_cs[:,1,1]).*100,
    lw=3, label="",
    color=:darkblue, line=:solid,
    title=L"Welfare cost - $R^\iota$ only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    fontfamily="serif",
    markershape=:circle,
    markersize=10,
    markercolor=:darkblue,
    legend=:topleft,
    bottommargin=(200,:px),
    dpi=400)
hline!([0.0],lw=1,color=:black,label="")
#remaining series in the order (l,k) = (2,1),(1,2),(2,2),(1,3),(2,3); (1,1) is drawn above
for k=1:3, l=1:2
    (k,l) == (1,1) && continue
    plot!(pctiles*100,(1.0.-Cequiv_avg_t0_Rf_cs[:,l,k]).*100,
        lw=3, label="",markershape=income_markers[k],
        markersize=10,
        markercolor=income_colors[k],
        color=income_colors[k], line=emp_lines[l])
end
#

# Panel p4: welfare cost across the wealth distribution with only τ varying
# (per the panel title). Each series plots 100*(1 - CE) against wealth
# percentile for one slice Cequiv_avg_t0_tau_cs[:, row, col]:
#   row -> line style    (1 = solid, 2 = dashed)
#   col -> marker/colour (1 = blue circles, 2 = green squares, 3 = black crosses)
# NOTE(review): the economic meaning of row/col is fixed where
# Cequiv_avg_t0_tau_cs is computed; confirm there.
p4=plot(pctiles*100, (1.0 .- Cequiv_avg_t0_tau_cs[:, 1, 1]) .* 100;
    title=L"Welfare cost - $\tau$ only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    fontfamily="serif",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    legend=:topleft,
    lw=3, label="", line=:solid, color=:darkblue,
    markershape=:circle, markersize=10, markercolor=:darkblue,
    rightmargin=(150,:px),
    dpi=400)
# Zero reference line: values above it are welfare losses relative to π=0.
hline!([0.0], lw=1, color=:black, label="")
# Draw the other five slices in the original order: (2,1), (1,2), (2,2), (1,3), (2,3).
for (col, shape, clr) in ((1, :circle, :darkblue),
                          (2, :rect, :darkgreen),
                          (3, :xcross, :black))
    for (row, style) in ((1, :solid), (2, :dash))
        (row, col) == (1, 1) && continue   # already drawn by plot() above
        plot!(pctiles*100, (1.0 .- Cequiv_avg_t0_tau_cs[:, row, col]) .* 100;
            lw=3, label="", line=style, color=clr,
            markershape=shape, markersize=10, markercolor=clr)
    end
end
#

# Panel p5: welfare cost across the wealth distribution with only θ varying
# (per the panel title). Each series plots 100*(1 - CE) against wealth
# percentile for one slice Cequiv_avg_t0_theta_cs[:, row, col]:
#   row -> line style    (1 = solid, 2 = dashed)
#   col -> marker/colour (1 = blue circles, 2 = green squares, 3 = black crosses)
# NOTE(review): the economic meaning of row/col is fixed where
# Cequiv_avg_t0_theta_cs is computed; confirm there.
p5=plot(pctiles*100, (1.0 .- Cequiv_avg_t0_theta_cs[:, 1, 1]) .* 100;
    title=L"Welfare cost - $\theta$ only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    fontfamily="serif",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    legend=:topleft,
    lw=3, label="", line=:solid, color=:darkblue,
    markershape=:circle, markersize=10, markercolor=:darkblue,
    rightmargin=(150,:px),
    dpi=400)
# Zero reference line: values above it are welfare losses relative to π=0.
hline!([0.0], lw=1, color=:black, label="")
# Draw the other five slices in the original order: (2,1), (1,2), (2,2), (1,3), (2,3).
for (col, shape, clr) in ((1, :circle, :darkblue),
                          (2, :rect, :darkgreen),
                          (3, :xcross, :black))
    for (row, style) in ((1, :solid), (2, :dash))
        (row, col) == (1, 1) && continue   # already drawn by plot() above
        plot!(pctiles*100, (1.0 .- Cequiv_avg_t0_theta_cs[:, row, col]) .* 100;
            lw=3, label="", line=style, color=clr,
            markershape=shape, markersize=10, markercolor=clr)
    end
end
#

# Panel p6: welfare cost across the wealth distribution with only wages varying
# (per the panel title). Each series plots 100*(1 - CE) against wealth
# percentile for one slice Cequiv_avg_t0_wages_cs[:, row, col]:
#   row -> line style    (1 = solid, 2 = dashed)
#   col -> marker/colour (1 = blue circles, 2 = green squares, 3 = black crosses)
# NOTE(review): the economic meaning of row/col is fixed where
# Cequiv_avg_t0_wages_cs is computed; confirm there.
p6=plot(pctiles*100, (1.0 .- Cequiv_avg_t0_wages_cs[:, 1, 1]) .* 100;
    title="Welfare cost - wages only",
    xlabel="Total wealth percentile (%)",
    ylabel="Consumption Equivalent (%) to π=0",
    fontfamily="serif",
    tickfont=font(12,"serif"),
    guidefont=font(14,"serif"),
    legendfont=font(12,"serif"),
    legend=:topleft,
    lw=3, label="", line=:solid, color=:darkblue,
    markershape=:circle, markersize=10, markercolor=:darkblue,
    dpi=400)
# Zero reference line, matching the other decomposition panels (it was missing
# here): values above it are welfare losses relative to π=0.
hline!([0.0], lw=1, color=:black, label="")
# Draw the other five slices in the original order: (2,1), (1,2), (2,2), (1,3), (2,3).
for (col, shape, clr) in ((1, :circle, :darkblue),
                          (2, :rect, :darkgreen),
                          (3, :xcross, :black))
    for (row, style) in ((1, :solid), (2, :dash))
        (row, col) == (1, 1) && continue   # already drawn by plot() above
        plot!(pctiles*100, (1.0 .- Cequiv_avg_t0_wages_cs[:, row, col]) .* 100;
            lw=3, label="", line=style, color=clr,
            markershape=shape, markersize=10, markercolor=clr)
    end
end
#

# Assemble the six decomposition panels into a 2x3 grid; this becomes the
# current plot, which savefig below writes out.
plot(p1,p2,p3,p4,p5,p6,layout=(2,3),size=(1500,800))

# The output path is relative to the working directory set by cd(...) at the top
# of this script. Make sure the folder exists so the save at the end of a long
# run does not fail on a fresh checkout (mkpath is a no-op if it already exists).
mkpath("../figures")
savefig("../figures/welfare_cost_by_wealth_and_income_decomp_gwaste.pdf")

#END